{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/Bonus_UnsupervisedLearning/Tutorial3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neuromatch Academy 2020: Week 3, Day 5\n",
    "# Tutorial 3: Autoencoder applications\n",
    "\n",
    "__Content creators:__ Marco Brigham and the [CCNSS](https://www.ccnss.org/) team (2014-2018)\n",
    "\n",
    "__Content reviewers:__ Itzel Olivos, Karen Schroeder, Karolina Stosio, Kshitij Dwivedi, Spiros Chavlis, Michael Waskom"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Tutorial Objectives\n",
    "\n",
    "## Autoencoder applications\n",
    "\n",
    "How do autoencoders with rich internal representations perform on the MNIST cognitive task?\n",
    "\n",
    "How do autoencoders perceive unseen digit classes? \n",
    "\n",
    "How does ANN image encoding differ from human vision?\n",
    "\n",
    "We now have the tools and techniques to answer these questions and, hopefully, many others you may encounter in your research!\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "![MNIST cognitive task](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/mnist_task.png)\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "In this tutorial, you will:\n",
    "- Analyze how autoencoders perceive transformed data (added noise, occluded parts, and rotations) and how that perception evolves with short retraining sessions\n",
    "- Use autoencoders to visualize unseen digit classes\n",
    "- Understand visual encoding for fully connected ANN autoencoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:56.775845Z",
     "iopub.status.busy": "2021-05-25T01:05:56.775311Z",
     "iopub.status.idle": "2021-05-25T01:05:56.826921Z",
     "shell.execute_reply": "2021-05-25T01:05:56.827386Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Video 1: Applications\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"_bzW_jkH6l0\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Setup\n",
    "Please execute the cell(s) below to initialize the notebook environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:56.831895Z",
     "iopub.status.busy": "2021-05-25T01:05:56.831284Z",
     "iopub.status.idle": "2021-05-25T01:05:57.819710Z",
     "shell.execute_reply": "2021-05-25T01:05:57.818805Z"
    }
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from scipy import ndimage\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "from sklearn.datasets import fetch_openml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:57.833176Z",
     "iopub.status.busy": "2021-05-25T01:05:57.832142Z",
     "iopub.status.idle": "2021-05-25T01:05:57.866476Z",
     "shell.execute_reply": "2021-05-25T01:05:57.865366Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Figure settings\n",
    "\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:57.925614Z",
     "iopub.status.busy": "2021-05-25T01:05:57.908717Z",
     "iopub.status.idle": "2021-05-25T01:05:57.939065Z",
     "shell.execute_reply": "2021-05-25T01:05:57.939591Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Helper functions\n",
    "\n",
    "\n",
    "def downloadMNIST():\n",
    "  \"\"\"\n",
    "  Download MNIST dataset and transform it to torch.Tensor\n",
    "\n",
    "  Args:\n",
    "    None\n",
    "\n",
    "  Returns:\n",
    "    x_train : training images (torch.Tensor) (60000, 28, 28)\n",
    "    y_train : training labels (torch.Tensor) (60000, )\n",
    "    x_test  : test images (torch.Tensor) (10000, 28, 28)\n",
    "    y_test  : test labels (torch.Tensor) (10000, )\n",
    "  \"\"\"\n",
    "  X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame = False)\n",
    "  # Split the data: first 60000 samples for training, next 10000 for testing\n",
    "  n_train = 60000\n",
    "  n_test = 10000\n",
    "\n",
    "  train_idx = np.arange(0, n_train)\n",
    "  test_idx = np.arange(n_train, n_train + n_test)\n",
    "\n",
    "  x_train, y_train = X[train_idx], y[train_idx]\n",
    "  x_test, y_test = X[test_idx], y[test_idx]\n",
    "\n",
    "  # Transform np.ndarrays to torch.Tensor\n",
    "  x_train = torch.from_numpy(np.reshape(x_train,\n",
    "                                        (len(x_train),\n",
    "                                         28, 28)).astype(np.float32))\n",
    "  x_test = torch.from_numpy(np.reshape(x_test,\n",
    "                                       (len(x_test),\n",
    "                                        28, 28)).astype(np.float32))\n",
    "\n",
    "  y_train = torch.from_numpy(y_train.astype(int))\n",
    "  y_test = torch.from_numpy(y_test.astype(int))\n",
    "\n",
    "  return (x_train, y_train, x_test, y_test)\n",
    "\n",
    "\n",
    "def init_weights_kaiming_uniform(layer):\n",
    "  \"\"\"\n",
    "  Initializes weights from linear PyTorch layer\n",
    "  with kaiming uniform distribution.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.nn.Module)\n",
    "        PyTorch layer\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "  # check for linear PyTorch layer\n",
    "  if isinstance(layer, nn.Linear):\n",
    "    # initialize weights with kaiming uniform distribution\n",
    "    nn.init.kaiming_uniform_(layer.weight.data)\n",
    "\n",
    "\n",
    "def init_weights_kaiming_normal(layer):\n",
    "  \"\"\"\n",
    "  Initializes weights from linear PyTorch layer\n",
    "  with kaiming normal distribution.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.nn.Module)\n",
    "        PyTorch layer\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "  # check for linear PyTorch layer\n",
    "  if isinstance(layer, nn.Linear):\n",
    "    # initialize weights with kaiming normal distribution\n",
    "    nn.init.kaiming_normal_(layer.weight.data)\n",
    "\n",
    "\n",
    "def get_layer_weights(layer):\n",
    "  \"\"\"\n",
    "  Retrieves learnable parameters from PyTorch layer.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.nn.Module)\n",
    "        PyTorch layer\n",
    "\n",
    "  Returns:\n",
    "    list with learnable parameters\n",
    "  \"\"\"\n",
    "  # initialize output list\n",
    "  weights = []\n",
    "\n",
    "  # copy numpy array representation of each set of learnable parameters\n",
    "  # (layer.parameters() is a generator, so it is always truthy; the loop\n",
    "  # simply leaves the list empty if the layer has no parameters)\n",
    "  for item in layer.parameters():\n",
    "    weights.append(item.detach().numpy())\n",
    "\n",
    "  return weights\n",
    "\n",
    "\n",
    "def eval_mse(y_pred, y_true):\n",
    "  \"\"\"\n",
    "  Evaluates mean squared error (MSE) between y_pred and y_true\n",
    "\n",
    "  Args:\n",
    "    y_pred (torch.Tensor)\n",
    "        prediction samples\n",
    "\n",
    "    y_true (torch.Tensor)\n",
    "        ground truth samples\n",
    "\n",
    "  Returns:\n",
    "    MSE(y_pred, y_true)\n",
    "  \"\"\"\n",
    "\n",
    "  with torch.no_grad():\n",
    "    criterion = nn.MSELoss()\n",
    "    loss = criterion(y_pred, y_true)\n",
    "\n",
    "  return float(loss)\n",
    "\n",
    "\n",
    "def eval_bce(y_pred, y_true):\n",
    "  \"\"\"\n",
    "  Evaluates binary cross-entropy (BCE) between y_pred and y_true\n",
    "\n",
    "  Args:\n",
    "    y_pred (torch.Tensor)\n",
    "        prediction samples\n",
    "\n",
    "    y_true (torch.Tensor)\n",
    "        ground truth samples\n",
    "\n",
    "  Returns:\n",
    "    BCE(y_pred, y_true)\n",
    "  \"\"\"\n",
    "\n",
    "  with torch.no_grad():\n",
    "    criterion = nn.BCELoss()\n",
    "    loss = criterion(y_pred, y_true)\n",
    "\n",
    "  return float(loss)\n",
    "\n",
    "\n",
    "def plot_row(images, show_n=10, image_shape=None):\n",
    "  \"\"\"\n",
    "  Plots rows of images from list of iterables (iterables: list, numpy array\n",
    "  or torch.Tensor). Also accepts single iterable.\n",
    "  Randomly selects images in each list element if item count > show_n.\n",
    "\n",
    "  Args:\n",
    "    images (iterable or list of iterables)\n",
    "        single iterable with images, or list of iterables\n",
    "\n",
    "    show_n (integer)\n",
    "        maximum number of images per row\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image if vectorized form\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if not isinstance(images, (list, tuple)):\n",
    "    images = [images]\n",
    "\n",
    "  for items_idx, items in enumerate(images):\n",
    "\n",
    "    items = np.array(items)\n",
    "    if items.ndim == 1:\n",
    "      items = np.expand_dims(items, axis=0)\n",
    "\n",
    "    if len(items) > show_n:\n",
    "      selected = np.random.choice(len(items), show_n, replace=False)\n",
    "      items = items[selected]\n",
    "\n",
    "    if image_shape is not None:\n",
    "      items = items.reshape([-1] + list(image_shape))\n",
    "\n",
    "    plt.figure(figsize=(len(items) * 1.5, 2))\n",
    "    for image_idx, image in enumerate(items):\n",
    "\n",
    "      plt.subplot(1, len(items), image_idx + 1)\n",
    "      plt.imshow(image, cmap='gray', vmin=image.min(), vmax=image.max())\n",
    "      plt.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "\n",
    "def to_s2(u):\n",
    "  \"\"\"\n",
    "  Projects 3D coordinates onto the surface of the unit sphere S2 and returns\n",
    "  their spherical coordinates (theta, phi):\n",
    "  theta: [0, pi]\n",
    "  phi: [-pi, pi]\n",
    "\n",
    "  Args:\n",
    "    u (list, numpy array or torch.Tensor of floats)\n",
    "        3D coordinates\n",
    "\n",
    "  Returns:\n",
    "    Spherical coordinates (theta, phi) on surface of unit sphere S2.\n",
    "  \"\"\"\n",
    "\n",
    "  x, y, z = (u[:, 0], u[:, 1], u[:, 2])\n",
    "  r = np.sqrt(x**2 + y**2 + z**2)\n",
    "  theta = np.arccos(z / r)\n",
    "  phi = np.arctan2(x, y)\n",
    "\n",
    "  return np.array([theta, phi]).T\n",
    "\n",
    "\n",
    "def to_u3(s):\n",
    "  \"\"\"\n",
    "  Converts spherical coordinates (theta, phi) on the surface of the unit\n",
    "  sphere S2 to 3D Cartesian coordinates, i.e. (1, theta, phi) ---> (x, y, z).\n",
    "\n",
    "  Args:\n",
    "    s (list, numpy array or torch.Tensor of floats)\n",
    "        2D spherical coordinates (theta, phi) on unit sphere S2\n",
    "\n",
    "  Returns:\n",
    "    3D Cartesian coordinates on surface of unit sphere S2\n",
    "  \"\"\"\n",
    "\n",
    "  theta, phi = (s[:, 0], s[:, 1])\n",
    "  x = np.sin(theta) * np.sin(phi)\n",
    "  y = np.sin(theta) * np.cos(phi)\n",
    "  z = np.cos(theta)\n",
    "\n",
    "  return np.array([x, y, z]).T\n",
    "\n",
    "\n",
    "def xy_lim(x):\n",
    "  \"\"\"\n",
    "  Returns arguments for plt.xlim and plt.ylim calculated from the minimum\n",
    "  and maximum of x, padded by 5% of the data range.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        data to be plotted\n",
    "\n",
    "  Returns:\n",
    "    xlim, ylim: lists [min, max] for the first and second coordinates.\n",
    "  \"\"\"\n",
    "  x_min = np.min(x, axis=0)\n",
    "  x_max = np.max(x, axis=0)\n",
    "\n",
    "  # pad both limits by 5% of the original range (computed once)\n",
    "  padding = np.abs(x_max - x_min) * 0.05 + np.finfo(float).eps\n",
    "  x_min = x_min - padding\n",
    "  x_max = x_max + padding\n",
    "\n",
    "  return [x_min[0], x_max[0]], [x_min[1], x_max[1]]\n",
    "\n",
    "\n",
    "def plot_generative(x, decoder_fn, image_shape, n_row=16, s2=False):\n",
    "  \"\"\"\n",
    "  Plots images reconstructed by decoder_fn from a 2D grid in\n",
    "  latent space that is determined by minimum and maximum values in x.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D or 3D coordinates in latent space\n",
    "\n",
    "    decoder_fn (callable)\n",
    "        function returning vectorized images from 2D latent space coordinates\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "    n_row (integer)\n",
    "        number of rows in grid\n",
    "\n",
    "    s2 (boolean)\n",
    "        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if s2:\n",
    "    x = to_s2(np.array(x))\n",
    "\n",
    "  xlim, ylim = xy_lim(np.array(x))\n",
    "\n",
    "  # grid cell sizes along each latent axis\n",
    "  dx = (xlim[1] - xlim[0]) / n_row\n",
    "  dy = (ylim[1] - ylim[0]) / n_row\n",
    "  grid = [np.linspace(ylim[0] + dy / 2, ylim[1] - dy / 2, n_row),\n",
    "          np.linspace(xlim[0] + dx / 2, xlim[1] - dx / 2, n_row)]\n",
    "\n",
    "  canvas = np.zeros((image_shape[0] * n_row, image_shape[1] * n_row))\n",
    "\n",
    "  cmap = plt.get_cmap('gray')\n",
    "\n",
    "  for j, latent_y in enumerate(grid[0][::-1]):\n",
    "    for i, latent_x in enumerate(grid[1]):\n",
    "\n",
    "      latent = np.array([[latent_x, latent_y]], dtype=np.float32)\n",
    "\n",
    "      if s2:\n",
    "        latent = to_u3(latent)\n",
    "\n",
    "      with torch.no_grad():\n",
    "        x_decoded = decoder_fn(torch.from_numpy(latent))\n",
    "\n",
    "      x_decoded = x_decoded.reshape(image_shape)\n",
    "\n",
    "      canvas[j * image_shape[0]: (j + 1) * image_shape[0],\n",
    "             i * image_shape[1]: (i + 1) * image_shape[1]] = x_decoded\n",
    "\n",
    "  plt.imshow(canvas, cmap=cmap, vmin=canvas.min(), vmax=canvas.max())\n",
    "  plt.axis('off')\n",
    "\n",
    "\n",
    "def plot_latent(x, y, show_n=500, s2=False, fontdict=None, xy_labels=None):\n",
    "  \"\"\"\n",
    "  Plots digit class of each sample in 2D latent space coordinates.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space\n",
    "\n",
    "    y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    show_n (integer)\n",
    "        maximum number of samples to plot\n",
    "\n",
    "    s2 (boolean)\n",
    "        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)\n",
    "\n",
    "    fontdict (dictionary)\n",
    "        style option for plt.text\n",
    "\n",
    "    xy_labels (list)\n",
    "        optional list with [xlabel, ylabel]\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if fontdict is None:\n",
    "    fontdict = {'weight': 'bold', 'size': 12}\n",
    "\n",
    "  if s2:\n",
    "    x = to_s2(np.array(x))\n",
    "\n",
    "  cmap = plt.get_cmap('tab10')\n",
    "\n",
    "  if len(x) > show_n:\n",
    "    selected = np.random.choice(len(x), show_n, replace=False)\n",
    "    x = x[selected]\n",
    "    y = y[selected]\n",
    "\n",
    "  for my_x, my_y in zip(x, y):\n",
    "    plt.text(my_x[0], my_x[1], str(int(my_y)),\n",
    "             color=cmap(int(my_y) / 10.),\n",
    "             fontdict=fontdict,\n",
    "             horizontalalignment='center',\n",
    "             verticalalignment='center',\n",
    "             alpha=0.8)\n",
    "\n",
    "  xlim, ylim = xy_lim(np.array(x))\n",
    "  plt.xlim(xlim)\n",
    "  plt.ylim(ylim)\n",
    "\n",
    "  if s2:\n",
    "    # x-axis: theta (first column of to_s2), y-axis: phi (second column)\n",
    "    if xy_labels is None:\n",
    "      xy_labels = [r'$\\theta$', r'$\\varphi$']\n",
    "\n",
    "    plt.xticks(np.arange(0, np.pi + np.pi / 6, np.pi / 6),\n",
    "               ['0', r'$\\pi/6$', r'$\\pi/3$', r'$\\pi/2$',\n",
    "                r'$2\\pi/3$', r'$5\\pi/6$', r'$\\pi$'])\n",
    "    plt.yticks(np.arange(-np.pi, np.pi + np.pi / 3, np.pi / 3),\n",
    "               [r'$-\\pi$', r'$-2\\pi/3$', r'$-\\pi/3$', '0',\n",
    "                r'$\\pi/3$', r'$2\\pi/3$', r'$\\pi$'])\n",
    "\n",
    "  if xy_labels is None:\n",
    "    xy_labels = ['$Z_1$', '$Z_2$']\n",
    "\n",
    "  plt.xlabel(xy_labels[0])\n",
    "  plt.ylabel(xy_labels[1])\n",
    "\n",
    "\n",
    "def plot_latent_generative(x, y, decoder_fn, image_shape, s2=False,\n",
    "                           title=None, xy_labels=None):\n",
    "  \"\"\"\n",
    "  Plots two side-by-side subplots: encoder map (left) and decoder grid (right).\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space\n",
    "\n",
    "    y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    decoder_fn (callable)\n",
    "        function returning vectorized images from 2D latent space coordinates\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "    s2 (boolean)\n",
    "        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)\n",
    "\n",
    "    title (string)\n",
    "        plot title\n",
    "\n",
    "    xy_labels (list)\n",
    "        optional list with [xlabel, ylabel]\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  fig = plt.figure(figsize=(12, 6))\n",
    "\n",
    "  if title is not None:\n",
    "    fig.suptitle(title, y=1.05)\n",
    "\n",
    "  ax = fig.add_subplot(121)\n",
    "  ax.set_title('Encoder map', y=1.05)\n",
    "  plot_latent(x, y, s2=s2, xy_labels=xy_labels)\n",
    "\n",
    "  ax = fig.add_subplot(122)\n",
    "  ax.set_title('Decoder grid', y=1.05)\n",
    "  plot_generative(x, decoder_fn, image_shape, s2=s2)\n",
    "\n",
    "  plt.tight_layout()\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "def plot_latent_ab(x1, x2, y, selected_idx=None,\n",
    "                   title_a='Before', title_b='After', show_n=500, s2=False):\n",
    "  \"\"\"\n",
    "  Plots two side-by-side encoder maps sharing the same axis limits.\n",
    "\n",
    "  Args:\n",
    "    x1 (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space (left plot)\n",
    "\n",
    "    x2 (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space (right plot)\n",
    "\n",
    "    y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    selected_idx (list of integers)\n",
    "        indices of elements to be plotted\n",
    "\n",
    "    title_a, title_b (string)\n",
    "        titles of the left and right plots\n",
    "\n",
    "    show_n (integer)\n",
    "        maximum number of samples in each plot\n",
    "\n",
    "    s2 (boolean)\n",
    "        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  fontdict = {'weight': 'bold', 'size': 12}\n",
    "\n",
    "  if len(x1) > show_n:\n",
    "\n",
    "    if selected_idx is None:\n",
    "      selected_idx = np.random.choice(len(x1), show_n, replace=False)\n",
    "\n",
    "    x1 = x1[selected_idx]\n",
    "    x2 = x2[selected_idx]\n",
    "    y = y[selected_idx]\n",
    "\n",
    "  data = np.concatenate([x1, x2])\n",
    "\n",
    "  if s2:\n",
    "    xlim, ylim = xy_lim(to_s2(data))\n",
    "\n",
    "  else:\n",
    "    xlim, ylim = xy_lim(data)\n",
    "\n",
    "  plt.figure(figsize=(12, 6))\n",
    "\n",
    "  ax = plt.subplot(121)\n",
    "  ax.set_title(title_a, y=1.05)\n",
    "  plot_latent(x1, y, fontdict=fontdict, s2=s2)\n",
    "  plt.xlim(xlim)\n",
    "  plt.ylim(ylim)\n",
    "\n",
    "  ax = plt.subplot(122)\n",
    "  ax.set_title(title_b, y=1.05)\n",
    "  plot_latent(x2, y, fontdict=fontdict, s2=s2)\n",
    "  plt.xlim(xlim)\n",
    "  plt.ylim(ylim)\n",
    "  plt.tight_layout()\n",
    "\n",
    "\n",
    "def runSGD(net, input_train, input_test, out_train=None, out_test=None,\n",
    "           optimizer=None, criterion='bce', n_epochs=10, batch_size=32,\n",
    "           verbose=False):\n",
    "  \"\"\"\n",
    "  Trains an autoencoder network with mini-batch stochastic gradient descent\n",
    "  using the given optimizer and loss criterion. Training samples are shuffled\n",
    "  at each epoch, and train and test losses (for the chosen criterion) are\n",
    "  displayed at the end of each epoch. If verbose, the final MSE and BCE losses\n",
    "  are also displayed. Plots training loss at each mini-batch (maximum of 500\n",
    "  evenly spaced values).\n",
    "\n",
    "  Args:\n",
    "    net (torch network)\n",
    "        ANN network (nn.Module)\n",
    "\n",
    "    input_train (torch.Tensor)\n",
    "        vectorized input images from train set\n",
    "\n",
    "    input_test (torch.Tensor)\n",
    "        vectorized input images from test set\n",
    "\n",
    "    out_train (torch.Tensor)\n",
    "        optional target images from train set\n",
    "\n",
    "    out_test (torch.Tensor)\n",
    "        optional target images from test set (if out_train or out_test is\n",
    "        None, the input images are used as targets)\n",
    "\n",
    "    optimizer (torch optimizer)\n",
    "        optional optimizer (defaults to Adam with default settings)\n",
    "\n",
    "    criterion (string)\n",
    "        train loss: 'bce' or 'mse'\n",
    "\n",
    "    n_epochs (integer)\n",
    "        number of full iterations over the training data\n",
    "\n",
    "    batch_size (integer)\n",
    "        number of elements in mini-batches\n",
    "\n",
    "    verbose (boolean)\n",
    "        whether to print final MSE and BCE losses\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if out_train is not None and out_test is not None:\n",
    "    different_output = True\n",
    "  else:\n",
    "    different_output = False\n",
    "\n",
    "  # Initialize loss function\n",
    "  if criterion == 'mse':\n",
    "    loss_fn = nn.MSELoss()\n",
    "  elif criterion == 'bce':\n",
    "    loss_fn = nn.BCELoss()\n",
    "  else:\n",
    "    raise ValueError('Please specify either \"mse\" or \"bce\" for loss criterion')\n",
    "\n",
    "  # Initialize optimizer (Adam by default)\n",
    "  if optimizer is None:\n",
    "    optimizer = optim.Adam(net.parameters())\n",
    "\n",
    "  # Placeholder for loss\n",
    "  track_loss = []\n",
    "\n",
    "  print('Epoch', '\\t', 'Loss train', '\\t', 'Loss test')\n",
    "  for i in range(n_epochs):\n",
    "\n",
    "    shuffle_idx = np.random.permutation(len(input_train))\n",
    "    batches = torch.split(input_train[shuffle_idx], batch_size)\n",
    "\n",
    "    if different_output:\n",
    "      batches_out = torch.split(out_train[shuffle_idx], batch_size)\n",
    "\n",
    "    for batch_idx, batch in enumerate(batches):\n",
    "\n",
    "      output_train = net(batch)\n",
    "\n",
    "      if different_output:\n",
    "        loss = loss_fn(output_train, batches_out[batch_idx])\n",
    "      else:\n",
    "        loss = loss_fn(output_train, batch)\n",
    "\n",
    "      optimizer.zero_grad()\n",
    "      loss.backward()\n",
    "      optimizer.step()\n",
    "\n",
    "      # Keep track of loss at each mini-batch\n",
    "      track_loss += [float(loss)]\n",
    "\n",
    "    loss_epoch = f'{i+1}/{n_epochs}'\n",
    "    with torch.no_grad():\n",
    "      output_train = net(input_train)\n",
    "      if different_output:\n",
    "        loss_train = loss_fn(output_train, out_train)\n",
    "      else:\n",
    "        loss_train = loss_fn(output_train, input_train)\n",
    "\n",
    "      loss_epoch += f'\\t {loss_train:.4f}'\n",
    "\n",
    "      output_test = net(input_test)\n",
    "      if different_output:\n",
    "        loss_test = loss_fn(output_test, out_test)\n",
    "      else:\n",
    "        loss_test = loss_fn(output_test, input_test)\n",
    "\n",
    "      loss_epoch += f'\\t\\t {loss_test:.4f}'\n",
    "\n",
    "    print(loss_epoch)\n",
    "\n",
    "  if verbose:\n",
    "    # Print loss\n",
    "    if different_output:\n",
    "      loss_mse = f'\\nMSE\\t {eval_mse(output_train, out_train):0.4f}'\n",
    "      loss_mse += f'\\t\\t {eval_mse(output_test, out_test):0.4f}'\n",
    "    else:\n",
    "      loss_mse = f'\\nMSE\\t {eval_mse(output_train, input_train):0.4f}'\n",
    "      loss_mse += f'\\t\\t {eval_mse(output_test, input_test):0.4f}'\n",
    "    print(loss_mse)\n",
    "\n",
    "    if different_output:\n",
    "      loss_bce = f'BCE\\t {eval_bce(output_train, out_train):0.4f}'\n",
    "      loss_bce += f'\\t\\t {eval_bce(output_test, out_test):0.4f}'\n",
    "    else:\n",
    "      loss_bce = f'BCE\\t {eval_bce(output_train, input_train):0.4f}'\n",
    "      loss_bce += f'\\t\\t {eval_bce(output_test, input_test):0.4f}'\n",
    "    print(loss_bce)\n",
    "\n",
    "  # Plot loss\n",
    "  step = int(np.ceil(len(track_loss)/500))\n",
    "  x_range = np.arange(0, len(track_loss), step)\n",
    "  plt.figure()\n",
    "  plt.plot(x_range, track_loss[::step], 'C0')\n",
    "  plt.xlabel('Iterations')\n",
    "  plt.ylabel('Loss')\n",
    "  plt.xlim([0, None])\n",
    "  plt.ylim([0, None])\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "def image_occlusion(x, image_shape):\n",
    "  \"\"\"\n",
    "  Randomly selects one quadrant of each image and sets it to zero.\n",
    "\n",
    "  Args:\n",
    "    x (torch.Tensor of floats)\n",
    "        vectorized images\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "  Returns:\n",
    "    torch.Tensor.\n",
    "  \"\"\"\n",
    "\n",
    "  selection = np.random.choice(4, len(x))\n",
    "\n",
    "  my_x = np.array(x).copy()\n",
    "  my_x = my_x.reshape(-1, image_shape[0], image_shape[1])\n",
    "\n",
    "  my_x[selection == 0, :int(image_shape[0] / 2), :int(image_shape[1] / 2)] = 0\n",
    "  my_x[selection == 1, int(image_shape[0] / 2):, :int(image_shape[1] / 2)] = 0\n",
    "  my_x[selection == 2, :int(image_shape[0] / 2), int(image_shape[1] / 2):] = 0\n",
    "  my_x[selection == 3, int(image_shape[0] / 2):, int(image_shape[1] / 2):] = 0\n",
    "\n",
    "  my_x = my_x.reshape(x.shape)\n",
    "\n",
    "  return torch.from_numpy(my_x)\n",
    "\n",
    "\n",
    "def image_rotation(x, deg, image_shape):\n",
    "  \"\"\"\n",
    "  Randomly rotates each image by an angle drawn uniformly from [-deg, deg] degrees.\n",
    "\n",
    "  Args:\n",
    "    x (torch.Tensor of floats)\n",
    "        vectorized images\n",
    "\n",
    "    deg (integer)\n",
    "        rotation range\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "  Returns:\n",
    "    torch.Tensor.\n",
    "  \"\"\"\n",
    "\n",
    "  my_x = np.array(x).copy()\n",
    "  my_x = my_x.reshape(-1, image_shape[0], image_shape[1])\n",
    "\n",
    "  for idx, item in enumerate(my_x):\n",
    "    my_deg = deg * 2 * np.random.random() - deg\n",
    "    my_x[idx] = ndimage.rotate(my_x[idx], my_deg,\n",
    "                               reshape=False, prefilter=False)\n",
    "\n",
    "  my_x = my_x.reshape(x.shape)\n",
    "\n",
    "  return torch.from_numpy(my_x)\n",
    "\n",
    "\n",
    "class AutoencoderClass(nn.Module):\n",
    "  \"\"\"\n",
    "  Deep autoencoder network object (nn.Module) with optional L2 normalization\n",
    "  of activations in bottleneck layer.\n",
    "\n",
    "  Args:\n",
    "    input_size (integer)\n",
    "        size of input samples\n",
    "\n",
    "    s2 (boolean)\n",
    "        whether to L2 normalize activations in bottleneck layer\n",
    "\n",
    "  Returns:\n",
    "    Autoencoder object inherited from nn.Module class.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self, input_size=784, s2=False):\n",
    "\n",
    "    super().__init__()\n",
    "\n",
    "    self.input_size = input_size\n",
    "    self.s2 = s2\n",
    "\n",
    "    if s2:\n",
    "      self.encoding_size = 3\n",
    "\n",
    "    else:\n",
    "      self.encoding_size = 2\n",
    "\n",
    "    self.enc1 = nn.Linear(self.input_size, int(self.input_size / 2))\n",
    "    self.enc1_f = nn.PReLU()\n",
    "    self.enc2 = nn.Linear(int(self.input_size / 2), self.encoding_size * 32)\n",
    "    self.enc2_f = nn.PReLU()\n",
    "    self.enc3 = nn.Linear(self.encoding_size * 32, self.encoding_size)\n",
    "    self.enc3_f = nn.PReLU()\n",
    "    self.dec1 = nn.Linear(self.encoding_size, self.encoding_size * 32)\n",
    "    self.dec1_f = nn.PReLU()\n",
    "    self.dec2 = nn.Linear(self.encoding_size * 32, int(self.input_size / 2))\n",
    "    self.dec2_f = nn.PReLU()\n",
    "    self.dec3 = nn.Linear(int(self.input_size / 2), self.input_size)\n",
    "    self.dec3_f = nn.Sigmoid()\n",
    "\n",
    "  def encoder(self, x):\n",
    "    \"\"\"\n",
    "    Encoder component.\n",
    "    \"\"\"\n",
    "    x = self.enc1_f(self.enc1(x))\n",
    "    x = self.enc2_f(self.enc2(x))\n",
    "    x = self.enc3_f(self.enc3(x))\n",
    "\n",
    "    if self.s2:\n",
    "      x = nn.functional.normalize(x, p=2, dim=1)\n",
    "\n",
    "    return x\n",
    "\n",
    "  def decoder(self, x):\n",
    "    \"\"\"\n",
    "    Decoder component.\n",
    "    \"\"\"\n",
    "    x = self.dec1_f(self.dec1(x))\n",
    "    x = self.dec2_f(self.dec2(x))\n",
    "    x = self.dec3_f(self.dec3(x))\n",
    "\n",
    "    return x\n",
    "\n",
    "  def forward(self, x):\n",
    "    \"\"\"\n",
    "    Forward pass.\n",
    "    \"\"\"\n",
    "    x = self.encoder(x)\n",
    "    x = self.decoder(x)\n",
    "\n",
    "    return x\n",
    "\n",
    "\n",
    "def save_checkpoint(net, optimizer, filename):\n",
    "  \"\"\"\n",
    "  Saves a PyTorch checkpoint.\n",
    "\n",
    "  Args:\n",
    "    net (torch network)\n",
    "        ANN network (nn.Module)\n",
    "\n",
    "    optimizer (torch optimizer)\n",
    "        optimizer for SGD\n",
    "\n",
    "    filename (string)\n",
    "        filename (without extension)\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  torch.save({'model_state_dict': net.state_dict(),\n",
    "              'optimizer_state_dict': optimizer.state_dict()},\n",
    "             filename+'.pt')\n",
    "\n",
    "\n",
    "def load_checkpoint(url, filename):\n",
    "  \"\"\"\n",
    "  Downloads a PyTorch checkpoint from URL if the local file is not present,\n",
    "  then loads it.\n",
    "\n",
    "  Args:\n",
    "    url (string)\n",
    "        URL location of PyTorch checkpoint\n",
    "\n",
    "    filename (string)\n",
    "        filename (without extension)\n",
    "\n",
    "  Returns:\n",
    "    PyTorch checkpoint of saved model.\n",
    "  \"\"\"\n",
    "\n",
    "  if not os.path.isfile(filename+'.pt'):\n",
    "    os.system(f\"wget {url}.pt\")\n",
    "\n",
    "  return torch.load(filename+'.pt')\n",
    "\n",
    "\n",
    "def reset_checkpoint(net, optimizer, checkpoint):\n",
    "  \"\"\"\n",
    "  Resets PyTorch model to checkpoint.\n",
    "\n",
    "  Args:\n",
    "    net (torch network)\n",
    "        ANN network (nn.Module)\n",
    "\n",
    "    optimizer (torch optimizer)\n",
    "        optimizer for SGD\n",
    "\n",
    "    checkpoint (torch checkpoint)\n",
    "        checkpoint of saved model\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  net.load_state_dict(checkpoint['model_state_dict'])\n",
    "  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
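    "# Section 1: Download and prepare MNIST dataset\n",
    "We use the helper function `downloadMNIST` to download the dataset as `torch.Tensor` objects and assign the train and test sets to (`x_train`, `y_train`) and (`x_test`, `y_test`). We then rescale pixel intensities from 0-255 to the range [0, 1].\n",
    "\n",
    "The variable `input_size` holds the length of the *vectorized* (flattened) images in `input_train` and `input_test`, the training and test inputs. We also draw random sample indices for plotting: ten test images (`test_selected_idx`), ten training images (`train_selected_idx`), and a subset of 500 test images (`test_subset_idx`).\n",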
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:57.946321Z",
     "iopub.status.busy": "2021-05-25T01:05:57.945683Z",
     "iopub.status.idle": "2021-05-25T01:06:20.633036Z",
     "shell.execute_reply": "2021-05-25T01:06:20.633824Z"
    }
   },
   "outputs": [],
   "source": [
    "# Download MNIST\n",
    "x_train, y_train, x_test, y_test = downloadMNIST()\n",
    "\n",
    "x_train = x_train / 255\n",
    "x_test = x_test / 255\n",
    "\n",
    "image_shape = x_train.shape[1:]\n",
    "\n",
    "input_size = np.prod(image_shape)\n",
    "\n",
    "input_train = x_train.reshape([-1, input_size])\n",
    "input_test = x_test.reshape([-1, input_size])\n",
    "\n",
    "test_selected_idx = np.random.choice(len(x_test), 10, replace=False)\n",
    "train_selected_idx = np.random.choice(len(x_train), 10, replace=False)\n",
    "\n",
    "test_subset_idx = np.random.choice(len(x_test), 500, replace=False)\n",
    "\n",
    "print(f'shape image \\t\\t {image_shape}')\n",
    "print(f'shape input_train \\t {input_train.shape}')\n",
    "print(f'shape input_test \\t {input_test.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 2: Download a pre-trained model\n",
    "The class `AutoencoderClass` implements the autoencoder architectures introduced in the previous tutorial. The design of this class follows the object-oriented programming (OOP) style from tutorial W3D4. Setting the boolean parameter `s2=True` specifies the model with projection onto the $S_2$ sphere.\n",
    "\n",
    "We trained both models for `n_epochs=25` and saved the weights to avoid a lengthy initial training period; these will be our reference model states.\n",
    "\n",
    "Each experiment starts from identical initial conditions: we reset the autoencoder to the reference state at the beginning of each exercise.\n",
    "\n",
    "PyTorch saves and loads models as follows:\n",
    "```python\n",
    "model = AutoencoderClass()  # or any nn.Module, e.g. nn.Sequential(...)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "filename_path = 'checkpoint.pt'  # example path\n",
    "\n",
    "# Save model and optimizer states in a single dictionary\n",
    "torch.save({'model_state_dict': model.state_dict(),\n",
    "            'optimizer_state_dict': optimizer.state_dict()},\n",
    "           filename_path)\n",
    "\n",
    "# Restore both states from the saved dictionary\n",
    "checkpoint = torch.load(filename_path)\n",
    "\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "```\n",
    "See the [PyTorch instructions](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html) for more details, including when to use `model.eval()` and `model.train()` with more complex models.\n",
    "\n",
    "We provide the functions `save_checkpoint`, `load_checkpoint`, and `reset_checkpoint` to implement the steps above and download pre-trained weights from the GitHub repo.\n",
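    "\n",
    "The provided `load_checkpoint` shells out to `wget`, which may not be installed on every system. A minimal pure-Python sketch of the download step follows, using the standard library's `urllib.request.urlretrieve`; the helper name `fetch_if_missing` is hypothetical and not part of this notebook.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import urllib.request\n",
    "\n",
    "\n",
    "def fetch_if_missing(url, path):\n",
    "  \"\"\"Download url to path unless path already exists (hypothetical helper).\"\"\"\n",
    "  if not os.path.isfile(path):\n",
    "    # Only hit the network when no local copy is present\n",
    "    urllib.request.urlretrieve(url, path)\n",
    "  return path\n",
    "```\n",
    "\n",
    "For example, `torch.load(fetch_if_missing(url + '.pt', filename + '.pt'))` could stand in for the body of `load_checkpoint`.\n",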
    "\n",
    "If downloading from GitHub fails, please uncomment the third cell below to train the model for `n_epochs=10` and save it locally.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:20.638495Z",
     "iopub.status.busy": "2021-05-25T01:06:20.637906Z",
     "iopub.status.idle": "2021-05-25T01:06:20.641114Z",
     "shell.execute_reply": "2021-05-25T01:06:20.641583Z"
    }
   },
   "outputs": [],
   "source": [
    "root = 'https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders'\n",
    "filename = 'ae_6h_prelu_bce_adam_25e_32b'\n",
    "url = os.path.join(root, filename)\n",
    "s2 = True\n",
    "\n",
    "if s2:\n",
    "  filename += '_s2'\n",
    "  url += '_s2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:20.646336Z",
     "iopub.status.busy": "2021-05-25T01:06:20.645799Z",
     "iopub.status.idle": "2021-05-25T01:06:21.168163Z",
     "shell.execute_reply": "2021-05-25T01:06:21.167556Z"
    }
   },
   "outputs": [],
   "source": [
    "model = AutoencoderClass(s2=s2)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "encoder = model.encoder\n",
    "decoder = model.decoder\n",
    "\n",
    "checkpoint = load_checkpoint(url, filename)\n",
    "\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "optimizer.load_state_dict(checkpoint['optimizer_state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:21.171239Z",
     "iopub.status.busy": "2021-05-25T01:06:21.170672Z",
     "iopub.status.idle": "2021-05-25T01:06:21.173584Z",
     "shell.execute_reply": "2021-05-25T01:06:21.173115Z"
    }
   },
   "outputs": [],
   "source": [
    "# Please uncomment and execute this cell if the download of\n",
    "# pre-trained weights fails\n",
    "\n",
    "# model = AutoencoderClass(s2=s2)\n",
    "# optimizer = optim.Adam(model.parameters())\n",
    "# encoder = model.encoder\n",
    "# decoder = model.decoder\n",
    "# n_epochs = 10\n",
    "# batch_size = 128\n",
    "# runSGD(model, input_train, input_test,\n",
    "#        n_epochs=n_epochs, batch_size=batch_size)\n",
    "# save_checkpoint(model, optimizer, filename)\n",
    "# checkpoint = load_checkpoint(url, filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:21.188855Z",
     "iopub.status.busy": "2021-05-25T01:06:21.188274Z",
     "iopub.status.idle": "2021-05-25T01:06:24.597028Z",
     "shell.execute_reply": "2021-05-25T01:06:24.597517Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "         image_shape=image_shape)\n",
    "\n",
    "plot_latent_generative(latent_test, y_test, decoder,\n",
    "                       image_shape=image_shape, s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 3: Applications of autoencoders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Application 1 - Image noise\n",
    "Removing noise added to images is a common showcase for dimensionality reduction techniques. The tutorial *W1D5 Dimensionality reduction* illustrated this capability with PCA.\n",
    "\n",
    "We first observe that autoencoders trained with noise-free images output noise-free images when receiving noisy images as input. However, the reconstructed images differ from the original (noise-free) images, because the added noise shifts the inputs to different coordinates in latent space.\n",
    "\n",
    "The ability to map noise-free and noisy versions of an image to similar regions in latent space is known as *robustness* or *invariance* to noise. How can we build such functionality into the autoencoder?\n",
    "\n",
    "The solution is to train the autoencoder on both noise-free and noisy versions of each image, with both mapping to the noise-free version. A faster alternative is to re-train the autoencoder for a few epochs with noisy images. These short training sessions fine-tune the weights so that noisy images map to latent space coordinates close to those of their noise-free versions.\n",
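    "\n",
    "The first option can be sketched with NumPy: each image appears twice in the training inputs, once clean and once with Gaussian noise, and both copies use the clean image as target. The helper name `make_denoising_set` and its defaults are illustrative assumptions; `noise_factor=0.4` mirrors the value used later in this section.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def make_denoising_set(x_clean, noise_factor=0.4, seed=0):\n",
    "  \"\"\"Pair clean and noisy inputs with clean targets (hypothetical helper).\n",
    "\n",
    "  x_clean: array of shape (n_samples, n_features) with values in [0, 1].\n",
    "  Returns (inputs, targets), each of shape (2 * n_samples, n_features).\n",
    "  \"\"\"\n",
    "  x_clean = np.asarray(x_clean, dtype=np.float32)\n",
    "  rng = np.random.default_rng(seed)\n",
    "  x_noisy = x_clean + noise_factor * rng.normal(size=x_clean.shape)\n",
    "  # Keep noisy pixels within the range of the clean images\n",
    "  x_noisy = np.clip(x_noisy, x_clean.min(), x_clean.max())\n",
    "  # Inputs: clean copies followed by noisy copies\n",
    "  inputs = np.concatenate([x_clean, x_noisy]).astype(np.float32)\n",
    "  # Targets: the clean image for both copies\n",
    "  targets = np.concatenate([x_clean, x_clean])\n",
    "  return inputs, targets\n",
    "\n",
    "\n",
    "x = np.random.default_rng(1).random((5, 784))\n",
    "inputs, targets = make_denoising_set(x)\n",
    "print(inputs.shape, targets.shape)  # (10, 784) (10, 784)\n",
    "```\n",
    "\n",
    "The resulting arrays could then be converted with `torch.from_numpy` and passed to `runSGD`, with `targets` as the `out_train` argument, as in the fine-tuning cells of this section.\n",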
    "\n",
    "Let's start by resetting to the reference state of the autoencoder.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cells below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:24.601819Z",
     "iopub.status.busy": "2021-05-25T01:06:24.601259Z",
     "iopub.status.idle": "2021-05-25T01:06:24.653913Z",
     "shell.execute_reply": "2021-05-25T01:06:24.654436Z"
    }
   },
   "outputs": [],
   "source": [
    "reset_checkpoint(model, optimizer, checkpoint)\n",
    "\n",
    "with torch.no_grad():\n",
    "  latent_test_ref = encoder(input_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reconstructions before fine-tuning\n",
    "Let's verify that an autoencoder trained on clean images will output clean images from noisy inputs. We visualize this by plotting three rows:\n",
    "* Top row with noisy images inputs\n",
    "* Middle row with reconstructions of noisy images\n",
    "* Bottom row with reconstructions of the original images (noise-free)\n",
    "\n",
    "![Noise task](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/applications_noise.png)\n",
    "\n",
    "The bottom row helps identify samples with reconstruction issues before adding noise. This row shows the baseline reconstruction quality for these samples rather than the original images. (Why?)\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:24.659671Z",
     "iopub.status.busy": "2021-05-25T01:06:24.659136Z",
     "iopub.status.idle": "2021-05-25T01:06:26.923262Z",
     "shell.execute_reply": "2021-05-25T01:06:26.922271Z"
    }
   },
   "outputs": [],
   "source": [
    "noise_factor = 0.4\n",
    "\n",
    "input_train_noisy = (input_train\n",
    "                     + noise_factor * np.random.normal(size=input_train.shape))\n",
    "input_train_noisy = np.clip(input_train_noisy, input_train.min(),\n",
    "                            input_train.max(), dtype=np.float32)\n",
    "\n",
    "input_test_noisy = (input_test\n",
    "                    + noise_factor * np.random.normal(size=input_test.shape))\n",
    "input_test_noisy = np.clip(input_test_noisy, input_test.min(),\n",
    "                           input_test.max(), dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:26.927798Z",
     "iopub.status.busy": "2021-05-25T01:06:26.927210Z",
     "iopub.status.idle": "2021-05-25T01:06:28.466010Z",
     "shell.execute_reply": "2021-05-25T01:06:28.466754Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test_noisy = model(input_test_noisy)\n",
    "  latent_test_noisy = encoder(input_test_noisy)\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test_noisy[test_selected_idx],\n",
    "          output_test_noisy[test_selected_idx],\n",
    "          output_test[test_selected_idx]], image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Latent space before fine-tuning\n",
    "We investigate the origin of reconstruction errors by looking at how adding noise to the input affects latent space coordinates. The decoder interprets significant coordinate changes as different digits.\n",
    "\n",
    "The function `plot_latent_ab` compares latent space coordinates for the same set of samples between two conditions. Here, we display coordinates for the ten samples from the previous cell before and after adding noise:\n",
    "* The left plot shows the coordinates of the original samples (noise-free)\n",
    "* The plot on the right shows the new coordinates after adding noise\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:28.484039Z",
     "iopub.status.busy": "2021-05-25T01:06:28.469966Z",
     "iopub.status.idle": "2021-05-25T01:06:28.881064Z",
     "shell.execute_reply": "2021-05-25T01:06:28.880554Z"
    }
   },
   "outputs": [],
   "source": [
    "plot_latent_ab(latent_test, latent_test_noisy, y_test, test_selected_idx,\n",
    "               title_a='Before noise', title_b='After noise', s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tuning the autoencoder with noisy images\n",
    "Let's re-train the autoencoder with noisy images as input and the original (noise-free) images as target, and regenerate the previous plots.\n",
    "\n",
    "After fine-tuning, both noisy and noise-free images map to similar locations in latent space. The network denoises the input through a latent space representation that is more robust to noise.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:28.887063Z",
     "iopub.status.busy": "2021-05-25T01:06:28.886523Z",
     "iopub.status.idle": "2021-05-25T01:06:57.597405Z",
     "shell.execute_reply": "2021-05-25T01:06:57.596896Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 3\n",
    "batch_size = 32\n",
    "\n",
    "model.train()\n",
    "\n",
    "runSGD(model, input_train_noisy, input_test_noisy,\n",
    "       out_train=input_train, out_test=input_test,\n",
    "       n_epochs=n_epochs, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:57.608731Z",
     "iopub.status.busy": "2021-05-25T01:06:57.608153Z",
     "iopub.status.idle": "2021-05-25T01:06:59.623271Z",
     "shell.execute_reply": "2021-05-25T01:06:59.622186Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test_noisy = model(input_test_noisy)\n",
    "  latent_test_noisy = encoder(input_test_noisy)\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test_noisy[test_selected_idx],\n",
    "          output_test_noisy[test_selected_idx],\n",
    "          output_test[test_selected_idx]], image_shape=image_shape)\n",
    "\n",
    "plot_latent_ab(latent_test, latent_test_noisy, y_test, test_selected_idx,\n",
    "               title_a='Before fine-tuning',\n",
    "               title_b='After fine-tuning', s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Global latent space shift\n",
    "The new latent space representation is more robust to noise and may result in a better internal representation of the dataset. We verify this by inspecting the latent space with clean images before and after fine-tuning with noisy images.\n",
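    "\n",
    "Besides the visual comparison, the size of this shift could be summarized numerically. A minimal sketch, assuming both latent arrays have shape `(n_samples, latent_dim)` and hold the same samples in the same order; `latent_shift` is a hypothetical helper, and Euclidean distance is just one simple choice of metric:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def latent_shift(latent_a, latent_b):\n",
    "  \"\"\"Per-sample Euclidean distance between two sets of latent codes (hypothetical helper).\"\"\"\n",
    "  latent_a = np.asarray(latent_a, dtype=float)\n",
    "  latent_b = np.asarray(latent_b, dtype=float)\n",
    "  # Distance along the latent dimension for each sample\n",
    "  return np.linalg.norm(latent_a - latent_b, axis=1)\n",
    "\n",
    "\n",
    "a = np.array([[0.0, 0.0], [1.0, 1.0]])\n",
    "b = np.array([[3.0, 4.0], [1.0, 1.0]])\n",
    "print(latent_shift(a, b))  # [5. 0.]\n",
    "```\n",
    "\n",
    "For instance, `latent_shift(latent_test_ref, latent_test).mean()` after running the next code cell would give the average displacement of the clean test images.\n",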
    "\n",
    "Fine-tuning the network with noisy images causes a *domain shift* in the dataset, i.e., a change in the distribution of images, since the dataset was initially composed of noise-free images. Depending on the task and the extent of changes during re-training (number of epochs, optimizer characteristics, etc.), the new latent space representation may become less well adapted to the original data as a side effect. How could we address *domain shift* and improve reconstructions of both noisy and noise-free images?\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:06:59.640809Z",
     "iopub.status.busy": "2021-05-25T01:06:59.640233Z",
     "iopub.status.idle": "2021-05-25T01:07:03.425915Z",
     "shell.execute_reply": "2021-05-25T01:07:03.426399Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,\n",
    "               title_a='Before fine-tuning',\n",
    "               title_b='After fine-tuning', s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Application 2 - Image occlusion\n",
    "We now investigate the effects of image occlusion. Based on the previous exercise, we expect the autoencoder to reconstruct complete images, since the training set does not contain occluded images (right?).\n",
    "\n",
    "We visualize this by plotting three rows:\n",
    "* Top row with occluded images\n",
    "* Middle row with reconstructions of occluded images\n",
    "* Bottom row with reconstructions of the original images\n",
    "\n",
    "![Occlusion task](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/applications_occlusion.png)\n",
    "\n",
    "Similarly, we investigate the source of the reconstruction errors by looking at how occluded images are represented in latent space and how this representation adjusts after fine-tuning.\n",
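    "\n",
    "The helper `image_occlusion` used below is provided with this notebook; its exact masking scheme is not shown in this section. As an illustration only, a hypothetical occlusion that blanks the right half of each flattened image could look like this (`occlude_half` is not a notebook function):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def occlude_half(images, image_shape=(28, 28), keep='left'):\n",
    "  \"\"\"Blank one half of each flattened image (hypothetical helper).\n",
    "\n",
    "  images: array of shape (n_samples, H * W).\n",
    "  keep: 'left' zeroes the right half, 'right' zeroes the left half.\n",
    "  \"\"\"\n",
    "  # Copy so the caller's array is not modified\n",
    "  x = np.array(images, dtype=np.float32).reshape((-1,) + tuple(image_shape))\n",
    "  half = image_shape[1] // 2\n",
    "  if keep == 'left':\n",
    "    x[:, :, half:] = 0\n",
    "  else:\n",
    "    x[:, :, :half] = 0\n",
    "  return x.reshape(len(x), -1)\n",
    "\n",
    "\n",
    "imgs = np.ones((2, 784))\n",
    "masked = occlude_half(imgs)\n",
    "print(masked.sum(axis=1))  # [392. 392.]\n",
    "```\n",
    "\n",
    "Since the masked inputs keep the flattened shape, they could be fed to the same `model` and `encoder` calls as the unmasked images.\n",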
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:03.430990Z",
     "iopub.status.busy": "2021-05-25T01:07:03.430440Z",
     "iopub.status.idle": "2021-05-25T01:07:03.479033Z",
     "shell.execute_reply": "2021-05-25T01:07:03.478516Z"
    }
   },
   "outputs": [],
   "source": [
    "reset_checkpoint(model, optimizer, checkpoint)\n",
    "\n",
    "with torch.no_grad():\n",
    "  latent_test_ref = encoder(input_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Before fine-tuning\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:03.482753Z",
     "iopub.status.busy": "2021-05-25T01:07:03.482245Z",
     "iopub.status.idle": "2021-05-25T01:07:03.657488Z",
     "shell.execute_reply": "2021-05-25T01:07:03.656604Z"
    }
   },
   "outputs": [],
   "source": [
    "input_train_mask = image_occlusion(input_train, image_shape=image_shape)\n",
    "input_test_mask = image_occlusion(input_test, image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:03.662357Z",
     "iopub.status.busy": "2021-05-25T01:07:03.661824Z",
     "iopub.status.idle": "2021-05-25T01:07:05.637316Z",
     "shell.execute_reply": "2021-05-25T01:07:05.636806Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test_mask = model(input_test_mask)\n",
    "  latent_test_mask = encoder(input_test_mask)\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test_mask[test_selected_idx],\n",
    "          output_test_mask[test_selected_idx],\n",
    "          output_test[test_selected_idx]], image_shape=image_shape)\n",
    "\n",
    "plot_latent_ab(latent_test, latent_test_mask, y_test, test_selected_idx,\n",
    "               title_a='Before occlusion', title_b='After occlusion', s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### After fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:05.649092Z",
     "iopub.status.busy": "2021-05-25T01:07:05.642420Z",
     "iopub.status.idle": "2021-05-25T01:07:38.029960Z",
     "shell.execute_reply": "2021-05-25T01:07:38.029318Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 3\n",
    "batch_size = 32\n",
    "\n",
    "model.train()\n",
    "\n",
    "runSGD(model, input_train_mask, input_test_mask,\n",
    "       out_train=input_train, out_test=input_test,\n",
    "       n_epochs=n_epochs, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:38.040669Z",
     "iopub.status.busy": "2021-05-25T01:07:38.040088Z",
     "iopub.status.idle": "2021-05-25T01:07:40.073961Z",
     "shell.execute_reply": "2021-05-25T01:07:40.073210Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test_mask = model(input_test_mask)\n",
    "  latent_test_mask = encoder(input_test_mask)\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test_mask[test_selected_idx],\n",
    "          output_test_mask[test_selected_idx],\n",
    "          output_test[test_selected_idx]], image_shape=image_shape)\n",
    "\n",
    "plot_latent_ab(latent_test, latent_test_mask, y_test, test_selected_idx,\n",
    "               title_a='Before fine-tuning',\n",
    "               title_b='After fine-tuning', s2=s2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:40.084932Z",
     "iopub.status.busy": "2021-05-25T01:07:40.084374Z",
     "iopub.status.idle": "2021-05-25T01:07:43.908411Z",
     "shell.execute_reply": "2021-05-25T01:07:43.909013Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,\n",
    "               title_a='Before fine-tuning',\n",
    "               title_b='After fine-tuning', s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Application 3 - Image rotation\n",
    "Finally, we look at the effect of image rotation on latent space coordinates. This task is arguably more challenging, since the network may need to substantially rework how it reconstructs images.\n",
    "\n",
    "We visualize this by plotting three rows:\n",
    "* Top row with rotated images\n",
    "* Middle row with reconstructions of rotated images\n",
    "* Bottom row with reconstructions of the original images\n",
    "\n",
    "![Rotation task](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/applications_rotation.png)\n",
    "\n",
    "We investigate the source of the reconstruction errors by looking at how rotated images are represented in latent space and how this representation adjusts after fine-tuning.\n",
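    "\n",
    "The helper `image_rotation` is also provided with this notebook. For 90-degree steps, rotating flattened images can be sketched with `np.rot90`; the helper name `rotate90` is hypothetical, and its direction convention (counter-clockwise) may differ from that of `image_rotation`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def rotate90(images, image_shape=(28, 28), k=1):\n",
    "  \"\"\"Rotate flattened images by k * 90 degrees counter-clockwise (hypothetical helper).\"\"\"\n",
    "  x = np.asarray(images).reshape((-1,) + tuple(image_shape))\n",
    "  # Rotate in the plane of the two image axes, leaving the sample axis alone\n",
    "  x = np.rot90(x, k=k, axes=(1, 2))\n",
    "  return x.reshape(len(x), -1)\n",
    "\n",
    "\n",
    "img = np.arange(4).reshape(1, 4)  # one 2x2 image [[0, 1], [2, 3]]\n",
    "print(rotate90(img, image_shape=(2, 2)))  # [[1 3 0 2]]\n",
    "```\n",
    "\n",
    "Unlike occlusion or noise, a rotation moves every pixel, which is one way to see why this application may call for more fine-tuning epochs.\n",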
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:43.923984Z",
     "iopub.status.busy": "2021-05-25T01:07:43.915181Z",
     "iopub.status.idle": "2021-05-25T01:07:43.959170Z",
     "shell.execute_reply": "2021-05-25T01:07:43.958687Z"
    }
   },
   "outputs": [],
   "source": [
    "reset_checkpoint(model, optimizer, checkpoint)\n",
    "\n",
    "with torch.no_grad():\n",
    "  latent_test_ref = encoder(input_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Before fine-tuning\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:43.962908Z",
     "iopub.status.busy": "2021-05-25T01:07:43.962384Z",
     "iopub.status.idle": "2021-05-25T01:07:54.763893Z",
     "shell.execute_reply": "2021-05-25T01:07:54.762976Z"
    }
   },
   "outputs": [],
   "source": [
    "input_train_rotation = image_rotation(input_train, 90, image_shape=image_shape)\n",
    "input_test_rotation = image_rotation(input_test, 90, image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:54.776762Z",
     "iopub.status.busy": "2021-05-25T01:07:54.776158Z",
     "iopub.status.idle": "2021-05-25T01:07:56.698560Z",
     "shell.execute_reply": "2021-05-25T01:07:56.699009Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test_rotation = model(input_test_rotation)\n",
    "  latent_test_rotation = encoder(input_test_rotation)\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test_rotation[test_selected_idx],\n",
    "          output_test_rotation[test_selected_idx],\n",
    "          output_test[test_selected_idx]], image_shape=image_shape)\n",
    "\n",
    "plot_latent_ab(latent_test, latent_test_rotation, y_test, test_selected_idx,\n",
    "               title_a='Before rotation', title_b='After rotation', s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### After fine-tuning\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:07:56.705202Z",
     "iopub.status.busy": "2021-05-25T01:07:56.704654Z",
     "iopub.status.idle": "2021-05-25T01:08:46.886859Z",
     "shell.execute_reply": "2021-05-25T01:08:46.885991Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 5\n",
    "batch_size = 32\n",
    "\n",
    "model.train()\n",
    "\n",
    "runSGD(model, input_train_rotation, input_test_rotation,\n",
    "       out_train=input_train, out_test=input_test,\n",
    "       n_epochs=n_epochs, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:08:46.900833Z",
     "iopub.status.busy": "2021-05-25T01:08:46.900256Z",
     "iopub.status.idle": "2021-05-25T01:08:48.884435Z",
     "shell.execute_reply": "2021-05-25T01:08:48.883903Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test_rotation = model(input_test_rotation)\n",
    "  latent_test_rotation = encoder(input_test_rotation)\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test_rotation[test_selected_idx],\n",
    "          output_test_rotation[test_selected_idx],\n",
    "          output_test[test_selected_idx]], image_shape=image_shape)\n",
    "\n",
    "plot_latent_ab(latent_test, latent_test_rotation, y_test, test_selected_idx,\n",
    "               title_a='Before fine-tuning',\n",
    "               title_b='After fine-tuning', s2=s2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:08:48.896818Z",
     "iopub.status.busy": "2021-05-25T01:08:48.896282Z",
     "iopub.status.idle": "2021-05-25T01:08:52.714178Z",
     "shell.execute_reply": "2021-05-25T01:08:52.714593Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,\n",
    "               title_a='Before fine-tuning',\n",
    "               title_b='After fine-tuning', s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Application 4 - What would digit \"6\" look like if we had never seen it before?\n",
    "Before we start melting our brains with such an impossible task, let's just ask the autoencoder to do it!\n",
    "\n",
    "We train the autoencoder from scratch without digit class `6` and visualize its reconstructions of digit `6` images.\n",
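    "\n",
    "Removing a digit class amounts to boolean masking on the labels. A minimal NumPy sketch that drops any set of classes with `np.isin`; the helper name `drop_classes` is hypothetical, and the cells below use a direct comparison such as `y_train != missing` instead:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def drop_classes(x, y, classes):\n",
    "  \"\"\"Return samples and labels whose label is not in classes (hypothetical helper).\"\"\"\n",
    "  x = np.asarray(x)\n",
    "  y = np.asarray(y)\n",
    "  keep = ~np.isin(y, classes)  # True for samples to keep\n",
    "  return x[keep], y[keep]\n",
    "\n",
    "\n",
    "x = np.arange(12).reshape(6, 2)\n",
    "y = np.array([0, 6, 1, 6, 2, 3])\n",
    "x_kept, y_kept = drop_classes(x, y, [6])\n",
    "print(y_kept)  # [0 1 2 3]\n",
    "```\n",
    "\n",
    "Passing a list of several classes removes all of them in one mask, which is equivalent to combining one `!=` condition per class with `&`.\n",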
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:08:52.719051Z",
     "iopub.status.busy": "2021-05-25T01:08:52.718441Z",
     "iopub.status.idle": "2021-05-25T01:08:52.727286Z",
     "shell.execute_reply": "2021-05-25T01:08:52.726570Z"
    }
   },
   "outputs": [],
   "source": [
    "model = AutoencoderClass(s2=s2)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "encoder = model.encoder\n",
    "decoder = model.decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:08:52.740977Z",
     "iopub.status.busy": "2021-05-25T01:08:52.740371Z",
     "iopub.status.idle": "2021-05-25T01:08:52.773200Z",
     "shell.execute_reply": "2021-05-25T01:08:52.774034Z"
    }
   },
   "outputs": [],
   "source": [
    "missing = 6\n",
    "\n",
    "my_input_train = input_train[y_train != missing]\n",
    "my_input_test = input_test[y_test != missing]\n",
    "my_y_test = y_test[y_test != missing]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:08:52.779763Z",
     "iopub.status.busy": "2021-05-25T01:08:52.778306Z",
     "iopub.status.idle": "2021-05-25T01:09:22.071448Z",
     "shell.execute_reply": "2021-05-25T01:09:22.070460Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 3\n",
    "batch_size = 32\n",
    "\n",
    "runSGD(model, my_input_train, my_input_test,\n",
    "       n_epochs=n_epochs, batch_size=batch_size)\n",
    "\n",
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "  my_latent_test = encoder(my_input_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:22.075604Z",
     "iopub.status.busy": "2021-05-25T01:09:22.075019Z",
     "iopub.status.idle": "2021-05-25T01:09:25.141542Z",
     "shell.execute_reply": "2021-05-25T01:09:25.142041Z"
    }
   },
   "outputs": [],
   "source": [
    "plot_row([input_test[y_test == missing], output_test[y_test == missing]],\n",
    "         image_shape=image_shape)\n",
    "\n",
    "plot_latent_generative(my_latent_test, my_y_test, decoder,\n",
    "                       image_shape=image_shape, s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 1: Removing the most dominant digit classes\n",
    "Digit classes `0` and `1` are dominant in the sense that these occupy large areas of the decoder grid, compared to other digit classes that occupy very little generative space.\n",
    "\n",
    "How will latent space change when removing the two most dominant digit classes? Will latent space re-distribute evenly among remaining classes or choose another two dominant classes?\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below\n",
    "* To select samples that satisfy two conditions, combine the boolean masks with `&`, e.g., `x[(cond_a) & (cond_b)]`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:25.146419Z",
     "iopub.status.busy": "2021-05-25T01:09:25.145857Z",
     "iopub.status.idle": "2021-05-25T01:09:25.154549Z",
     "shell.execute_reply": "2021-05-25T01:09:25.154065Z"
    }
   },
   "outputs": [],
   "source": [
    "model = AutoencoderClass(s2=s2)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "encoder = model.encoder\n",
    "decoder = model.decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:25.159541Z",
     "iopub.status.busy": "2021-05-25T01:09:25.158069Z",
     "iopub.status.idle": "2021-05-25T01:09:25.160177Z",
     "shell.execute_reply": "2021-05-25T01:09:25.160633Z"
    }
   },
   "outputs": [],
   "source": [
    "missing_a = 1\n",
    "missing_b = 0\n",
    "#################################################\n",
    "## TODO for students:\n",
    "#################################################\n",
    "# input train data\n",
    "# my_input_train = ...\n",
    "# input test data\n",
    "# my_input_test = ...\n",
    "# test labels\n",
    "# my_y_test = ...\n",
    "\n",
    "# Uncomment to test your code\n",
    "# print(my_input_train.shape)\n",
    "# print(my_input_test.shape)\n",
    "# print(my_y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**SAMPLE OUTPUT**\n",
    "\n",
    "```\n",
    "torch.Size([47335, 784])\n",
    "torch.Size([7885, 784])\n",
    "torch.Size([7885])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:25.173307Z",
     "iopub.status.busy": "2021-05-25T01:09:25.172689Z",
     "iopub.status.idle": "2021-05-25T01:09:25.205691Z",
     "shell.execute_reply": "2021-05-25T01:09:25.206169Z"
    }
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "missing_a = 1\n",
    "missing_b = 0\n",
    "# input train data\n",
    "my_input_train = input_train[(y_train != missing_a) & (y_train != missing_b)]\n",
    "# input test data\n",
    "my_input_test = input_test[(y_test != missing_a) & (y_test != missing_b)]\n",
    "# test labels\n",
    "my_y_test = y_test[(y_test != missing_a) & (y_test != missing_b)]\n",
    "\n",
    "# Check the shapes\n",
    "print(my_input_train.shape)\n",
    "print(my_input_test.shape)\n",
    "print(my_y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:25.212280Z",
     "iopub.status.busy": "2021-05-25T01:09:25.210487Z",
     "iopub.status.idle": "2021-05-25T01:09:50.000979Z",
     "shell.execute_reply": "2021-05-25T01:09:50.000459Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 3\n",
    "batch_size = 32\n",
    "\n",
    "runSGD(model, my_input_train, my_input_test,\n",
    "       n_epochs=n_epochs, batch_size=batch_size)\n",
    "\n",
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "  my_latent_test = encoder(my_input_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:50.005909Z",
     "iopub.status.busy": "2021-05-25T01:09:50.005343Z",
     "iopub.status.idle": "2021-05-25T01:09:54.003239Z",
     "shell.execute_reply": "2021-05-25T01:09:54.003728Z"
    }
   },
   "outputs": [],
   "source": [
    "plot_row([input_test[y_test == missing_a], output_test[y_test == missing_a]],\n",
    "         image_shape=image_shape)\n",
    "\n",
    "plot_row([input_test[y_test == missing_b], output_test[y_test == missing_b]],\n",
    "         image_shape=image_shape)\n",
    "\n",
    "plot_latent_generative(my_latent_test, my_y_test, decoder,\n",
    "                       image_shape=image_shape, s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 4: ANNs? Same but different!\n",
    "\"Same same but different\" is an expression used in some parts of Asia to express differences between supposedly similar subjects. In this exercise, we investigate a fundamental difference in how fully-connected ANNs process visual information compared to human vision.\n",
    "\n",
    "The previous exercises showed ANN autoencoders performing cognitive tasks with relative ease. However, a crucial aspect of how these ANNs process images is already built into the vectorization of images: once an image is flattened into a vector, this fully-connected architecture has no notion of which pixels are neighbors, so it completely ignores the relative position of pixels. To illustrate this, we show that learning proceeds just as well with shuffled pixel locations.\n",
    "\n",
    "We randomly shuffle image pixels with a reversible shuffle map stored in `shuffle_image_idx`; the reverse map `shuffle_rev_image_idx` restores the original pixel order.\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "![mnist_pixel_shuffle](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/mnist_pixel_shuffle.png)\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "Given a set of shuffled images `input_shuffle`, the unshuffled images are recovered as follows:\n",
    "```\n",
    "input_shuffle[:, shuffle_rev_image_idx]\n",
    "```\n",
    "\n",
    "Below, we set up the reversible shuffle map and visualize a few images with shuffled and unshuffled pixels, followed by their noisy versions.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell(s) below"
   ]
  },
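  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reverse map built with an explicit loop in the next cell is the inverse permutation of `shuffle_image_idx`. As a minimal NumPy sketch on a toy 6-pixel image (an illustrative assumption, not the tutorial data), `np.argsort` of a permutation returns the same inverse, and indexing with it restores the original pixel order:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_pixels = 6                       # toy image size (illustrative)\n",
    "image = np.arange(n_pixels) * 1.5  # stand-in for one flattened image\n",
    "\n",
    "shuffle_idx = rng.permutation(n_pixels)\n",
    "\n",
    "# Loop-based inverse, mirroring the tutorial cell\n",
    "rev_idx_loop = np.empty_like(shuffle_idx)\n",
    "for pos_idx, pos in enumerate(shuffle_idx):\n",
    "    rev_idx_loop[pos] = pos_idx\n",
    "\n",
    "# Vectorized inverse: argsort of a permutation is its inverse\n",
    "rev_idx = np.argsort(shuffle_idx)\n",
    "\n",
    "shuffled = image[shuffle_idx]\n",
    "restored = shuffled[rev_idx]\n",
    "\n",
    "print(np.array_equal(rev_idx, rev_idx_loop))  # True\n",
    "print(np.array_equal(restored, image))        # True\n",
    "```\n",
    "\n",
    "For a 2D batch of flattened images, the same indices are applied along the pixel axis, as in `x[:, shuffle_idx]` and `x[:, rev_idx]`."
   ]
  },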
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:54.009812Z",
     "iopub.status.busy": "2021-05-25T01:09:54.009226Z",
     "iopub.status.idle": "2021-05-25T01:09:55.524694Z",
     "shell.execute_reply": "2021-05-25T01:09:55.524209Z"
    }
   },
   "outputs": [],
   "source": [
    "# create forward and reverse indexes for pixel shuffling\n",
    "shuffle_image_idx = np.arange(input_size)\n",
    "shuffle_rev_image_idx = np.empty_like(shuffle_image_idx)\n",
    "\n",
    "# shuffle pixel location\n",
    "np.random.shuffle(shuffle_image_idx)\n",
    "\n",
    "# store reverse locations\n",
    "for pos_idx, pos in enumerate(shuffle_image_idx):\n",
    "  shuffle_rev_image_idx[pos] = pos_idx\n",
    "\n",
    "# shuffle train and test sets\n",
    "input_train_shuffle = input_train[:, shuffle_image_idx]\n",
    "input_test_shuffle = input_test[:, shuffle_image_idx]\n",
    "\n",
    "input_train_shuffle_noisy = input_train_noisy[:, shuffle_image_idx]\n",
    "input_test_shuffle_noisy = input_test_noisy[:, shuffle_image_idx]\n",
    "\n",
    "# show samples with shuffled pixels\n",
    "plot_row([input_test_shuffle,\n",
    "          input_test_shuffle[:, shuffle_rev_image_idx]],\n",
    "         image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:55.536440Z",
     "iopub.status.busy": "2021-05-25T01:09:55.535887Z",
     "iopub.status.idle": "2021-05-25T01:09:56.561778Z",
     "shell.execute_reply": "2021-05-25T01:09:56.561273Z"
    }
   },
   "outputs": [],
   "source": [
    "# show noisy samples with shuffled pixels\n",
    "plot_row([input_test_shuffle_noisy[test_selected_idx],\n",
    "          input_test_shuffle_noisy[:, shuffle_rev_image_idx][test_selected_idx]],\n",
    "         image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we initialize a fresh network and train it on the denoising task with shuffled pixels.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:09:56.566202Z",
     "iopub.status.busy": "2021-05-25T01:09:56.565684Z",
     "iopub.status.idle": "2021-05-25T01:10:26.282781Z",
     "shell.execute_reply": "2021-05-25T01:10:26.283229Z"
    }
   },
   "outputs": [],
   "source": [
    "model = AutoencoderClass(s2=s2)\n",
    "\n",
    "encoder = model.encoder\n",
    "decoder = model.decoder\n",
    "\n",
    "n_epochs = 3\n",
    "batch_size = 32\n",
    "\n",
    "# train the model to denoise shuffled images\n",
    "runSGD(model, input_train_shuffle_noisy, input_test_shuffle_noisy,\n",
    "       out_train=input_train_shuffle, out_test=input_test_shuffle,\n",
    "       n_epochs=n_epochs, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we visualize the reconstructions and the latent space representation of the trained model.\n",
    "\n",
    "The reconstructions are plotted in three rows:\n",
    "* Top row: shuffled noisy images\n",
    "* Middle row: denoised reconstructions, still shuffled\n",
    "* Bottom row: denoised reconstructions with pixels unshuffled\n",
    "\n",
    "![mnist_pixel_shuffle denoised](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/applications_ann_denoise.png)\n",
    "\n",
    "The encoder map shows a similar organization to before. Obtaining similar internal representations confirms that the network ignores the relative position of pixels. The decoder grid looks different from before because it generates shuffled images.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below"
   ]
  },
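  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why does shuffling leave learning unaffected? For a fully connected layer, any fixed pixel permutation can be absorbed by permuting the columns of its weight matrix, so the network can represent exactly the same functions of the shuffled images. The NumPy sketch below checks this for a single dense layer with ReLU; the layer sizes and random weights are illustrative assumptions, not the tutorial's `AutoencoderClass`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "n_in, n_hidden, n_samples = 784, 16, 5  # illustrative sizes\n",
    "\n",
    "W = rng.normal(size=(n_hidden, n_in))   # first-layer weights\n",
    "b = rng.normal(size=n_hidden)           # first-layer bias\n",
    "X = rng.random(size=(n_samples, n_in))  # stand-in for flattened images\n",
    "\n",
    "perm = rng.permutation(n_in)            # fixed pixel shuffle\n",
    "\n",
    "def dense_relu(X, W, b):\n",
    "    # Fully connected layer followed by ReLU, applied to each row of X\n",
    "    return np.maximum(X @ W.T + b, 0.0)\n",
    "\n",
    "h_original = dense_relu(X, W, b)\n",
    "# Shuffled inputs combined with correspondingly shuffled weight columns\n",
    "h_shuffled = dense_relu(X[:, perm], W[:, perm], b)\n",
    "\n",
    "print(np.allclose(h_original, h_shuffled))  # True\n",
    "```\n",
    "\n",
    "Training on shuffled images can therefore find the correspondingly permuted weights, which is why the encoder map organizes much as it did with unshuffled images."
   ]
  },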
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:10:26.289665Z",
     "iopub.status.busy": "2021-05-25T01:10:26.289091Z",
     "iopub.status.idle": "2021-05-25T01:10:29.963718Z",
     "shell.execute_reply": "2021-05-25T01:10:29.964208Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test_shuffle_noisy = encoder(input_test_shuffle_noisy)\n",
    "  output_test_shuffle_noisy = model(input_test_shuffle_noisy)\n",
    "\n",
    "plot_row([input_test_shuffle_noisy[test_selected_idx],\n",
    "          output_test_shuffle_noisy[test_selected_idx],\n",
    "          output_test_shuffle_noisy[:, shuffle_rev_image_idx][test_selected_idx]],\n",
    "         image_shape=image_shape)\n",
    "\n",
    "plot_latent_generative(latent_test_shuffle_noisy, y_test, decoder,\n",
    "                       image_shape=image_shape, s2=s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Summary\n",
    "Hooray! You have finished the last tutorial of NMA 2020!\n",
    "\n",
    "We hope you've enjoyed these tutorials and learned how autoencoders can model rich, non-linear representations of data. Perhaps you will find them useful in your research, to model certain aspects of cognition or even to extend them to biologically plausible architectures (autoencoders of spiking neurons, anyone?).\n",
    "\n",
    "These are the key takeaway messages from these tutorials:\n",
    "\n",
    "**Autoencoders trained in *learning by doing* tasks such as compression/decompression, removing noise, etc. can uncover rich lower-dimensional structure embedded in structured images and other cognitively relevant data.**\n",
    "\n",
    "**The data domain seen during training imprints a \"cognitive bias\" - you only see what you expect to see, which can only be similar to what you saw before.**\n",
    "\n",
    "Such bias is related to the concept [*What you see is all there is*](https://en.wikipedia.org/wiki/Thinking,_Fast_and_Slow) coined by Daniel Kahneman in psychology.\n",
    "\n",
    "For additional applications of autoencoders to neuroscience, see the spike sorting application in the outro video, and [this article](https://www.nature.com/articles/s41592-018-0109-9) on how to replicate the input-output relationship of real networks of neurons with autoencoders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:10:29.971107Z",
     "iopub.status.busy": "2021-05-25T01:10:29.969658Z",
     "iopub.status.idle": "2021-05-25T01:10:30.022524Z",
     "shell.execute_reply": "2021-05-25T01:10:30.022021Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Video 2: Wrap-up\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"ziiZK9P6AXQ\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "JgF7_zvb8d0C"
   ],
   "include_colab_link": true,
   "name": "Tutorial3",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
